import scipy.io as scio
import torchvision
import torch
from torch import nn 
from torch.utils.data import Dataset,DataLoader,TensorDataset
from torchvision import datasets, transforms
import time
import numpy as np 
import pandas as pd
import random
import math
from torch.nn import functional as F
import csv

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import sys, random, time
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import csv



def setup_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) with ``seed``.

    Also switches cuDNN to deterministic kernels so repeated runs match.
    """
    # The generators are independent, so the seeding order does not matter.
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

# Global seed; it is also embedded in checkpoint and result file names.
seedsss = 200
setup_seed(seedsss)

# One reference-trajectory file per online round: 50 x 'M', 50 x 'E',
# 50 x 'T', then 51 x 'A' (201 entries), in that order.
filename_list_whole = [
    "../../ref_traj/" + letter + "_reftraj.mat"
    for letter, count in (('M', 50), ('E', 50), ('T', 50), ('A', 51))
    for _ in range(count)
]
# One 2-D constraint centre per round, drawn from N(0, 1) right after seeding.
center_list_whole = np.random.normal(0, 1, [len(filename_list_whole), 2])

# Adaptation hyper-parameters.
batch_size_K = 400      # training batch size used in test_result
meta_lambda = 100.0     # default weight of bias_reg
n_epochs = 40           # adaptation epochs per (round, expert)

# Defaults for constraint_voilations (bound at its definition time).
redius = 2.0            # constraint circle radius
less = False            # False: penalise positions inside the circle; True: outside
weight = 500.0          # penalty scale
softplus_para = 200.0   # softplus beta (sharpness of the penalty)
    
class Model(torch.nn.Module):
    """Time-parameterised 2-D trajectory network.

    A 4-layer MLP (8 -> 128 -> 128 -> 128 -> 2) is applied to a fixed feature
    expansion of a one-column time input, and the result is scaled by 10.
    Weights live in ``self.params`` and are passed explicitly to the forward
    functions, so the same network can be evaluated with another parameter
    list.
    """

    def __init__(self):
        super(Model, self).__init__()

        def _weight(out_features, in_features):
            # U(-1/sqrt(fan_in), 1/sqrt(fan_in)) initialisation.
            bound = 1. / math.sqrt(in_features)
            return torch.nn.Parameter(
                torch.empty(out_features, in_features).uniform_(-bound, bound))

        def _bias(size):
            return torch.nn.Parameter(torch.zeros(size))

        # Wrap explicitly in nn.Parameter: older PyTorch versions of
        # ParameterList reject plain tensors. The creation order (and thus
        # RNG consumption) is weight, bias, weight, bias, ...
        self.params = torch.nn.ParameterList([
            _weight(128, 8), _bias(128),
            _weight(128, 128), _bias(128),
            _weight(128, 128), _bias(128),
            _weight(2, 128), _bias(2),
        ])

    def dense(self, x, params):
        """Apply the MLP with weights/biases ``params[0..7]`` (ReLU between layers)."""
        y = F.linear(x, params[0], params[1])
        y = F.relu(y)

        y = F.linear(y, params[2], params[3])
        y = F.relu(y)

        y = F.linear(y, params[4], params[5])
        y = F.relu(y)

        y = F.linear(y, params[6], params[7])

        return y

    def input_process(self, x):
        """Expand ``x`` into [x, x^2, x^3, x^4, sin(pi x), cos(pi x), sin(2 pi x), cos(2 pi x)].

        The constant 3.14 (not math.pi) is deliberate. Changing it would
        alter the outputs of already-trained, saved models.
        """
        x2 = torch.pow(x, 2)
        x3 = torch.pow(x, 3)
        x4 = torch.pow(x, 4)
        x_sin = torch.sin(x * 3.14)
        x_cos = torch.cos(x * 3.14)
        x_sin_2 = torch.sin(2 * x * 3.14)
        x_cos_2 = torch.cos(2 * x * 3.14)
        return torch.cat((x, x2, x3, x4, x_sin, x_cos, x_sin_2, x_cos_2), 1)

    def forward(self, x, params):
        """Return [p1, p2, dp1/dx, dp2/dx] as a ``(N, 4)`` tensor.

        ``x`` must require grad, because the velocities are obtained by
        differentiating the positions with respect to it. ``create_graph=True``
        keeps those derivatives differentiable, so losses on them can be
        back-propagated.
        """
        position = self.dense(self.input_process(x), params) * 10.0
        position1, position2 = position.split([1, 1], dim=1)
        # grad_outputs must match the differentiated output's shape/device.
        ones = torch.ones_like(position1)
        vel1 = torch.autograd.grad(position1, x, ones, retain_graph=True, create_graph=True)[0]
        vel2 = torch.autograd.grad(position2, x, ones, retain_graph=True, create_graph=True)[0]
        return torch.cat((position1, position2, vel1, vel2), 1)

    def forward1(self, x, params):
        """Return only the scaled 2-D positions (no velocities)."""
        position = self.dense(self.input_process(x), params) * 10.0
        return position

def my_mse_loss(outputs, Q, Sigma):
    """Mean quadratic error ``e^T Sigma^{-1} e`` with ``e = outputs - Q``.

    The error is viewed as rows of 4 values. ``Sigma`` supplies a 4x4
    matrix per row (presumably ``(B, 4, 4)``; it must be invertible and
    broadcast against ``(B, 4, 4)``).
    Returns the mean over all rows as a 0-d tensor.
    """
    err = (outputs - Q).reshape(-1, 4)
    err_col = err.unsqueeze(-1)   # (B, 4, 1)
    err_row = err.unsqueeze(-2)   # (B, 1, 4)
    precision = torch.inverse(Sigma)
    return (err_row @ precision @ err_col).mean()

def constraint_voilations(outputs, center=(0.0, 0.0), less=less, redius=redius, weight=weight):
    """Smooth penalty for leaving/entering a circle around ``center``.

    Args:
        outputs: tensor whose first two columns are the 2-D positions (the
            remaining two columns, e.g. velocities, are ignored).
        center: circle centre (any sequence/array of two numbers).
        less: if True, penalise points outside the circle (distance >
            ``redius``); if False, penalise points inside it.
        redius: circle radius.
        weight: scale applied to the penalty.

    Returns:
        Mean over rows of ``(softplus(margin, beta=softplus_para) - 0.001) * weight``.
        ``margin`` is ``distance - redius`` (less=True) or ``redius - distance``.
        The value can be slightly negative, down to ``-0.001 * weight``.
    """
    position, _vel = outputs.split([2, 2], dim=1)
    # Build the centre on the same device/dtype as the positions, so the
    # function also works on tensors that are not on the global ``device``.
    center_tensor = torch.as_tensor(center, dtype=position.dtype, device=position.device)
    distance = torch.norm(position - center_tensor, dim=1)
    if less:
        margin = distance - redius
    else:
        margin = -distance + redius
    penalty = (F.softplus(margin, softplus_para) - 0.001) * weight
    return torch.mean(penalty)

def bias_reg(params,meta_parameter, lambada=meta_lambda):
    """Return ``lambada * sum_i ||params[i] - meta_parameter[i]||^2``.

    Pulls the adapted parameters towards the meta-parameters. Entries of
    ``meta_parameter`` beyond ``len(params)`` are ignored. Returns the float
    ``0.0`` when ``params`` is empty; otherwise returns a tensor.
    """
    total = 0.0
    for idx, param in enumerate(params):
        dist = torch.norm(param - meta_parameter[idx])
        total = total + dist * dist
    return total * lambada

def adjust_learning_rate(optimizer, epoch, lr):
    """Set every parameter group of ``optimizer`` to learning rate ``lr``.

    ``epoch`` is accepted for call-site compatibility but is not used.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)


def _load_reference_trajectory(filename):
    """Read one reference trajectory from the ``'refTraj'`` entry of a ``.mat`` file.

    For each point ``data`` this collects:

    * ``data[0][0][0] - 1.0`` as the time input (one column);
    * ``data[1][k][0]`` for k = 0..3 as the 4-value target;
    * ``data[2] + 0.001 * I_4`` as the target covariance. The diagonal
      jitter is presumably there to keep it invertible for my_mse_loss.

    Returns:
        ``(t, y, sigma)`` numpy arrays stacked over the points.
    """
    t_data = []
    y_data = []
    sigma_data = []
    for data in scio.loadmat(filename)['refTraj'][0]:
        t_data.append([data[0][0][0] - 1.0])
        y_data.append([data[1][0][0], data[1][1][0], data[1][2][0], data[1][3][0]])
        sigma_data.append(data[2] + 0.001 * np.identity(4))
    return np.array(t_data), np.array(y_data), np.array(sigma_data)


def _make_task_loader(t_data, y_data, sigma_data, shuffle, batch_size):
    """Build a DataLoader over one task's (time, target, covariance) arrays.

    The time tensor requires grad because Model.forward differentiates the
    predicted positions with respect to it.
    """
    return torch.utils.data.DataLoader(
        TensorDataset(
            torch.tensor(t_data).float().requires_grad_(),
            torch.tensor(y_data).float(),
            torch.tensor(sigma_data).float(),
        ),
        shuffle=shuffle,
        batch_size=batch_size,
    )


def test_result(round, expert_id):
    """Adapt one expert's saved meta-model to task ``round`` and score it.

    Steps:

    1. Load ``./pkl/model_meta_<round>_expert_<expert_id>_<seedsss>.pkl``.
    2. Fine-tune it with Adam for ``n_epochs`` steps on the reference
       trajectory ``filename_list_whole[round]``. The training loss is the
       covariance-weighted error, plus ``bias_reg`` towards the loaded
       weights, plus ``lambada`` times the constraint penalty around
       ``center_list_whole[round]``. ``lambada`` is updated after every step.
    3. Evaluate the adapted model and print a summary.

    Args:
        round: task index into ``filename_list_whole``/``center_list_whole``.
            It also appears in the checkpoint name. The name shadows the
            ``round`` builtin inside this function.
        expert_id: base-learner index used in the checkpoint name.

    Returns:
        ``(test_mse_loss, test_constraint_loss)`` as Python floats, measured
        after the final epoch.

    Raises:
        IndexError: if ``round`` does not select any reference trajectory.
    """
    import copy

    test_file_name_list = filename_list_whole[round:round + 1]
    center_list_test = center_list_whole[round:round + 1, :]
    if not test_file_name_list:
        raise IndexError(f"no reference trajectory for round {round}")
    tasks = [_load_reference_trajectory(name) for name in test_file_name_list]

    checkpoint = './pkl/model_meta_' + str(round) + '_expert_' + str(expert_id) + '_' + str(seedsss) + '.pkl'

    for num_task, (t_data, y_data, sigma_data) in enumerate(tasks):
        # map_location puts every tensor on ``device``, whatever device the
        # checkpoint was saved from.
        # NOTE(review): torch.load unpickles arbitrary objects, so use it on
        # trusted files only. Since PyTorch 2.6 the default weights_only=True
        # rejects pickled nn.Modules; pass weights_only=False there if needed.
        model = torch.load(checkpoint, map_location=device).to(device)
        # Frozen copy of the starting weights: the anchor for bias_reg.
        model_meta = copy.deepcopy(model)
        meta_parameter = model_meta.params
        for param in meta_parameter:
            param.requires_grad = False

        learning_rate0 = 0.001
        optimizer = torch.optim.Adam(model.params, lr=learning_rate0, weight_decay=0.00001)
        lambada = 1.0       # multiplier on the constraint penalty, adapted below
        lr_lamabada = 0.04  # step size of the multiplier update

        data_loader_train = _make_task_loader(t_data, y_data, sigma_data, shuffle=True, batch_size=batch_size_K)
        data_loader_test = _make_task_loader(t_data, y_data, sigma_data, shuffle=False, batch_size=400)
        # Only the first batch of each loader is used, and it is reused in
        # every epoch.
        step_train = 0
        data_train_now = next(iter(data_loader_train))
        data_test_now = next(iter(data_loader_test))

        features, labels, sigmas = (tensor.to(device) for tensor in data_train_now)
        features_test, labels_test, sigmas_test = (tensor.to(device) for tensor in data_test_now)
        center = center_list_test[num_task]

        for epoch in range(n_epochs):
            # Primal step: fit + stay near the meta-parameters + weighted penalty.
            outputs = model(features, model.params)
            loss_train = my_mse_loss(outputs, labels, sigmas) + bias_reg(model.params, meta_parameter)
            outputs_constraint = model(features_test, model.params)
            loss_train += constraint_voilations(outputs_constraint, center=center) * lambada

            optimizer.zero_grad()
            # retain_graph=True is kept. The batch tensors are created once,
            # carry autograd history from the requires-grad dataset tensors,
            # and are reused every epoch, so backward passes through that
            # shared history repeatedly.
            loss_train.backward(retain_graph=True)
            optimizer.step()

            # Multiplier update driven by the post-step constraint value,
            # capped at +0.5.
            # NOTE(review): any negative value (constraint satisfied) is
            # mapped to -0.5 rather than clipped at -0.5. Kept as-is; confirm
            # this is intended.
            gradient_lambada = constraint_voilations(model(features_test, model.params), center=center).item()
            if gradient_lambada > 0.5:
                gradient_lambada = 0.5
            if gradient_lambada < 0:
                gradient_lambada = -0.5
            lambada += lr_lamabada * gradient_lambada
            if lambada < 0:
                lambada = 0.0

        # Evaluate once, after training, on the unshuffled first batch (up to
        # 400 points). Only these final values are reported.
        outputs = model(features_test, model.params)
        loss_test = my_mse_loss(outputs, labels_test, sigmas_test)
        constraint_test = constraint_voilations(outputs, center=center)

        print(f"round = {round}")
        print(f'epoch = {epoch+1}, step = {step_train+1}, train loss = {loss_train.item() / 1:.6f}, reg loss = {bias_reg(model.params,meta_parameter).item():.6f}')
        print(f'epoch = {epoch+1}, test mse loss = {loss_test.item() :.6f}, test constraint loss = {constraint_test.item()  :.6f}')
    return loss_test.item(), constraint_test.item()

class Meta:
    """AdaNormalHedge meta-algorithm over a set of (possibly sleeping) base-learners.

    Args:
        prob (numpy.ndarray): Prior weights ``pi_i`` over the base-learners.
            The array is copied, so it is never mutated.
        N (int): Total number of base-learners.

    Only the base-learners marked active by the latest call to
    :meth:`update_active_state` take part in updates and sampling. The
    ``w``/``R``/``C``/``prob``/``init_prob`` properties expose that active
    slice.
    """

    def __init__(self, prob: np.ndarray, N: int):
        # Float copy: the updates below write into it in place, and an int
        # array would truncate the probabilities.
        self._prob = np.array(prob, dtype=float)
        self._init_prob = self._prob.copy()
        self.t = 0
        self._R = np.zeros(N)  # cumulative regret per base-learner
        self._C = np.zeros(N)  # cumulative absolute regret per base-learner
        self._w = np.zeros(N)  # current weight w(R, C) per base-learner

    def _Phi(self, R, C):
        """Potential ``Phi(R, C) = exp([R]_+^2 / (3C))``."""
        R_plus = np.maximum(0, R)
        return np.exp(np.square(R_plus) / (3 * C))

    def _w_func(self, R, C):
        """AdaNormalHedge weight ``w(R, C) = (Phi(R+1, C+1) - Phi(R-1, C+1)) / 2``.

        Both potentials use ``C + 1`` (Luo & Schapire, 2015). Since C >= 0,
        the denominator 3(C + 1) is always >= 3. A ``C - 1`` form would divide
        by zero or a negative number whenever C <= 1.
        """
        return 0.5 * (self._Phi(R + 1, C + 1) - self._Phi(R - 1, C + 1))

    def _normalize_active(self):
        """Rescale the active probabilities to sum to 1.

        If they are all zero, fall back to the normalised prior (every active
        learner with R <= -1 gets w = 0).
        """
        total = np.sum(self.prob)
        if total > 0:
            self.prob = self.prob / total
        else:
            prior = self.init_prob
            self.prob = prior / np.sum(prior)

    def update_prob(self, loss_bases: np.ndarray, loss_meta):
        """Update regrets with this round's losses and recompute the probabilities.

        ``loss_bases`` holds one loss per currently active base-learner, in
        active-index order.
        """
        regret = loss_meta - loss_bases
        self.R += regret
        self.C += np.abs(regret)
        self.w = self._w_func(self.R, self.C)
        self.prob = self.init_prob * self.w
        self._normalize_active()

    def update_active_state(self, active_state):
        """Record which learners are active (state > 0).

        Learners in state 2 are (re)started: their R and C are zeroed and
        their weight is reset to ``w(0, 0)``.
        """
        self._active_state = active_state
        self._active_index = np.where(active_state > 0)[0]
        re_init_idx = np.where(self._active_state == 2)[0]
        self._R[re_init_idx], self._C[re_init_idx] = 0, 0
        self._w[re_init_idx] = self._w_func(0, 0)
        self._prob[re_init_idx] = self._init_prob[re_init_idx] * self._w[re_init_idx]
        self._normalize_active()

    def sample_expert(self):
        """Draw an active base-learner index according to the current probabilities."""
        return np.random.choice(self._active_index, p=self.prob)

    @property
    def w(self):
        return self._w[self._active_index]

    @w.setter
    def w(self, w):
        self._w[self._active_index] = w

    @property
    def R(self):
        return self._R[self._active_index]

    @R.setter
    def R(self, R):
        self._R[self._active_index] = R

    @property
    def C(self):
        return self._C[self._active_index]

    @C.setter
    def C(self, C):
        self._C[self._active_index] = C

    @property
    def prob(self):
        return self._prob[self._active_index]

    @prob.setter
    def prob(self, prob):
        self._prob[self._active_index] = prob

    @property
    def init_prob(self):
        """Get the initial probability over the current alive base-learners."""
        return self._init_prob[self._active_index]

class Schedule:
    """Geometric restart schedule for ``expert_num`` base-learners.

    Learner ``k`` is (re)started first at step ``2**k - 1`` and then every
    ``2**k`` steps. ``active_state`` holds 0 (not yet started), 1 (running)
    or 2 (restarted on the most recent ``update_t`` call).
    """

    def __init__(self, expert_num: int):
        self.active_state = np.zeros(expert_num)
        self.t = 0
        self.exp_num = expert_num
        # Next restart step per learner (kept as floats).
        self.next_k = np.array([2 ** k - 1 for k in range(expert_num)], dtype=float)
        self.time_checkpoint = np.zeros(expert_num, dtype=np.int_)

    def update_t(self):
        """Advance one step: settle last step's restarts, then fire the due ones."""
        # Learners restarted on the previous step become plain running ones.
        self.active_state[self.active_state == 2] = 1
        due = self.next_k == self.t
        self.time_checkpoint[due] = self.t
        self.active_state[due] = 2
        # Learner k's period is 2**k steps.
        self.next_k[due] += 2.0 ** np.flatnonzero(due)
        self.t += 1

    def get_active_state(self):
        """Return the (live, not copied) per-learner state array."""
        return self.active_state

if __name__ == "__main__":
    # Number of online rounds. Named to avoid shadowing the builtin ``round``.
    num_rounds = 199
    # floor(log2(num_rounds + 1)) + 1 experts, computed exactly with integers
    # (math.log(x, 2) can round just below an integer).
    expert_num = (num_rounds + 1).bit_length()
    meta = Meta(np.ones(expert_num), expert_num)
    scheduler = Schedule(expert_num)
    loss_scale = 0.01
    # Per-round, per-expert results; inactive experts keep the initial 0.
    loss_test = np.zeros((num_rounds, expert_num))
    constraint_test = np.zeros((num_rounds, expert_num))

    # ``with`` guarantees the CSV is closed even if a round raises.
    with open("result" + str(seedsss) + ".csv", "w", encoding='utf-8', newline='') as csvfile:
        writer = csv.writer(csvfile)
        for i in range(num_rounds):
            # Advance the restart schedule and tell the meta-learner who is active.
            scheduler.update_t()
            active_state = scheduler.get_active_state()
            meta.update_active_state(active_state)

            # Expert whose result is reported (submitted) for this round.
            model_submit_id = meta.sample_expert()

            for j in range(expert_num):
                if active_state[j]:
                    loss_test[i, j], constraint_test[i, j] = test_result(i + 1, j)

            active_idx = np.where(active_state > 0)[0]
            inst_loss = loss_test[i, active_idx] * loss_scale
            meta_loss = np.mean(inst_loss)
            meta.update_prob(inst_loss, meta_loss)
            writer.writerow([str(loss_test[i, model_submit_id]), str(constraint_test[i, model_submit_id])])
    